{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch \n",
    "import matplotlib.pyplot as plt\n",
    "import robust_loss_pytorch.general\n",
    "\n",
    "# Construct some regression data with some extreme outliers.\n",
    "# Seed both random number generators: numpy drives the synthetic data below,\n",
    "# and torch is seeded so that the randomly initialized model weights created\n",
    "# in later cells are repeatable when the notebook is run top to bottom.\n",
    "np.random.seed(1)\n",
    "torch.manual_seed(1)\n",
    "n = 50\n",
    "scale_true = 0.7\n",
    "shift_true = 0.15\n",
    "\n",
    "# Noisy samples of the line y = scale_true * x + shift_true.\n",
    "x = np.random.uniform(size=n)\n",
    "y = scale_true * x + shift_true\n",
    "y += np.random.normal(scale=0.025, size=n)\n",
    "\n",
    "# Turn the samples whose uniform draw exceeds 0.9 into outliers: values\n",
    "# above 0.5 are replaced by 0.05 and values below 0.5 are replaced by 0.85.\n",
    "flip_mask = np.random.uniform(size=n) > 0.9\n",
    "y = np.where(flip_mask, 0.05 + 0.4 * (1. - np.sign(y - 0.5)), y)\n",
    "\n",
    "# Convert to float32 tensors for the torch models used below.\n",
    "x = torch.tensor(x, dtype=torch.float32)\n",
    "y = torch.tensor(y, dtype=torch.float32)\n",
    "\n",
    "class RegressionModel(torch.nn.Module):\n",
    "    # Linear regression on scalar inputs, backed by a single\n",
    "    # torch.nn.Linear(1, 1) layer.\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        # One input feature mapped to one output feature.\n",
    "        self.linear = torch.nn.Linear(1, 1)\n",
    "\n",
    "    def forward(self, x):\n",
    "        # Insert a feature axis at position 1 so the layer receives one\n",
    "        # feature per sample, then select that single output column again.\n",
    "        features = x.unsqueeze(1)\n",
    "        predictions = self.linear(features)\n",
    "        return predictions[:, 0]\n",
    "\n",
    "def plot_regression(regression):\n",
    "    # A helper function for plotting a regression module.\n",
    "    # Draws the notebook-level data (x, y) as a scatter plot, the true line\n",
    "    # defined by scale_true / shift_true in black, and the output of the\n",
    "    # callable `regression` in red, all evaluated on 100 points in [0, 1].\n",
    "    x_plot = np.linspace(0, 1, 100)\n",
    "    # detach() drops the autograd graph so the result can become a numpy array.\n",
    "    y_plot = regression(torch.Tensor(x_plot)).detach().numpy()\n",
    "    y_plot_true = x_plot * scale_true + shift_true\n",
    "\n",
    "    plt.figure(0, figsize=(4,4))\n",
    "    plt.scatter(x, y, label='data')\n",
    "    plt.plot(x_plot, y_plot_true, color='k', label='true line')\n",
    "    plt.plot(x_plot, y_plot, color='r', label='fit')\n",
    "    # Label the axes and curves so the figure can be read on its own.\n",
    "    plt.xlabel('x')\n",
    "    plt.ylabel('y')\n",
    "    plt.legend()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0   : loss=6.892803\n",
      "100 : loss=1.706846\n",
      "200 : loss=1.674322\n",
      "300 : loss=1.673696\n",
      "400 : loss=1.673694\n",
      "500 : loss=1.673694\n",
      "600 : loss=1.673694\n",
      "700 : loss=1.673694\n",
      "800 : loss=1.673694\n",
      "900 : loss=1.673694\n",
      "1000: loss=1.673694\n",
      "1100: loss=1.673694\n",
      "1200: loss=1.673694\n",
      "1300: loss=1.673694\n",
      "1400: loss=1.673694\n",
      "1500: loss=1.673694\n",
      "1600: loss=1.673694\n",
      "1700: loss=1.673694\n",
      "1800: loss=1.673694\n",
      "1900: loss=1.673694\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQcAAAD8CAYAAAB6iWHJAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAIABJREFUeJzt3Xt8zuX/wPHXtdkRbQppJB0cckwkh0IiKkaoH5VOSlQiUghFk7N0UFIkcsipWYXJIedzk7PyXQ7bCLM5zOx4/f74bNy7fWb37L7v3ffu9/Px8Gi772uf+31v7b3rc33e1/ujtNYIIYQ1r8IOQAjhmiQ5CCFMSXIQQpiS5CCEMCXJQQhhSpKDEMKUJAchhClJDkIIU5IchBCmihXWC5cuXVpXqlSpsF5eCI+1c+fOM1rrMnmNK7TkUKlSJXbs2FFYLy+Ex1JKHbVlnJxWCCFMSXIQQpiS5CCEMCXJQQhhSpKDEMKUJAchhClJDkIIU5IchBCmJDkIIUwVWoVkYQiPimVc5CHiEpMJCQ5gQOuqdKhb3uWO6UjuFq8w54yfo8fMHMKjYhm0eA+xicloIDYxmUGL9xAeFetSx3Qkd4tXmDP7Ob49cTYLtv1r19fxmOQwLvIQyWkZOR5LTstgXOQhlzqmI7lbvMKc5c9Ra825LQs5/uNA3h0aZtfX8ZjTirjE5Hw9XljHdCR3i1eYy/556fQ04iO/JGnvKgKrPQw1n7Tr63hMcggJDiDW5JcgJDjApY7pSO4WrzBfWwgJDuBY3ElOLx5JSux+gpo8S1CTrlQoFWjX1/aY04oBrasS4OOd47EAH28GtK7qUsd0JHeL19PltkZUIyCR/2b2I/W/w5QOfY/gh54l0LeY3X+OHjNzyF7JtecKryOO6UjuFq+nM1sjij+4hWkRYwm6qSQVenzKhZJ3OOznqArrXpn169fX0uxFiNzdOfA3sn87tdZc2LGEhDXT8S17J5NnzOXVNg/c0HGVUju11vXzGucxpxVCuJvstSCdkcbZ5V+QsPo7Ais35NZnxzBhY7zDL0FLchDCRQ1oXRXftCT++2koF3ev4KZGz1C6w0C8fP2dcgnaY9YchHA31QIucGH+e6SciOOWtv0pUeORHM87+hK0TTMHpVQbpdQhpdRhpdRAk+crKqXWKKWilFK7lVJP2D9UITzHihUraNiwIZmpydR6bcI1iQEcfwk6z5mDUsobmAy0AmKA7UqpCK31fothQ4D5WuuvlVLVgaVAJQfEK4Rbsa5TeKRaGdYcPH3dq0Vffvklffv2pXr16vzyyy9EnS3GoMV7cly5cMYlaFtOKxoAh7XW0QBKqXlAe8AyOWjgpqyPg4A4ewYphDvKrlPI/qWOTUzmxy3HrjyfXbcAxmXmtLQ0+vbty1dffUW7du2YPXs2JUuW5I47jPHOvgRtS3IoDxy3+DwGeNBqzEfACqVUb6A40NLsQEqpHkAPgIoVK+Y3ViHcilmdgrXshcVmlQJ55plnWLlyJe+99x6ffPIJ3t5XC9Y61C3v9HoUW9YclMlj1sURXYEZWusKwBPALKXUNcfWWk/VWtfXWtcvUybPG+4I4dZsXTA8Gn2Yhg0bsnbtWr7//nvGjBmTIzEUFltmDjHA7RafV+Da04buQBsArfVmpZQ/UBo4ZY8ghXBHwYE+JFxKu+6Y5KN/cXbJaIICfVm1ahUPP/ywk6LLmy0zh+1AZaXUnUopX6ALEGE15hjwKIBS6l7AHzhtz0CFcDd5FR9f2LWMU/OHcdttt7Ft2zaXSgxgw8xBa52ulHoLiAS8gela631KqRHADq11BNAf+FYp9Q7GKcdLurDqsoVwEeeSzWcNOjODhNXTuLAzgrqNm7NmaThBQUFOji5vNhVBaa2XYlyetHxsmMXH+4Em9g1NCPdmtkU+MyWJC0vHc+Hv7bzzzjuMGzfOJdYXzEj5tBAOYr1FPi3x
JP/9OIAL0VFMnTqViRMnumxiACmfFsJhLLfIR+/ZzpnwT/AvpohYsYJHHrm24tHVyMxBCAfqULc8r956lPgFQ7n79tv4c8d2t0gMIMlBCIfJyMhgwIABvPLKKzRr1ozNmzdTuXJlx7zY6tVw8aJdDynJQQgHuHDhAh06dGD8+PGUaxjKP3V70/abKPv3YDhxArp2hUcfhc8/t+uhZc1BCDs7evQo7dq1Y9/+/dzauhd+9xldoa33UhRIejp89RUMHQopKTB8OPTrV9DQc5CZgxB2tHnzZho0aMCxY8eo9sIn+N+Xs128XZq0bNkCDzwAffpAo0awdy8MGwb+/gU7rhVJDkLYyaxZs2jevDklS5Zky5YtXCpbw3TcDTdpOXsWXn8dGjeG06dh/nxYtgzuuacAUedOkoMQBZSZmcngwYN54YUXaNy4MVu3bqVatWq5NmPJd5MWrWHGDKhaFaZNg7594cABePppUGb7Iu1DkoMQBZCUlETnzp0ZNWoUr732GpGRkdxyyy2Ane4TsmcPNG0KL78MVarAzp0wcSKULGnPt2FKkoMQN+j48eM8/PDDLFmyhIkTJ/LNN9/g6+t75fkOdcszqmMtygcHoIDywQGM6ljLtsXIixdhwACoW9eYJXz3HaxfD3XqOO4NWZGrFULcgG3bttG+fXuSkpL45ZdfeOIJ87ap+W7SojX8/LOx2BgTA927w5gxkDUbcSaZOQiRTz/99BPNmjXD39+fTZs25ZoY8i06Gtq2hU6d4OabYeNGY8ZQCIkBJDkIYbPMzEw+/PBDunTpQv369dm2bRs1a9Ys+IFTUuDjj6FGDVi3zlhT2LnTuCpRiOS0QggbXLp0iZdffpn58+fz0ksvMWXKFPz8/Ap+4JUr4Y034J9/jKsPn34K5V3j3qUycxAiD3FxcTRv3pwFCxYwZswYpk+fXvDEEBcHXbpAq1bGOsPy5UbdgoskBpCZg/BQ1veTyK3V+59//kloaCiJiYmEh4cTGhp6Q8e5IrvsecgQSE2FDz+EgQPtXt1oD5IchMcxu5+E2Z6HxYsX061bN0qXLs3GjRupY3UZ0dbjXLF1K/TsCbt2wWOPweTJDqtutAc5rRAex+x+EpZ7HrTWfPLJJ3Tq1Ik6deqwbdu2axKDLce5IrvsuVEjOHXKOH1YvtylEwPIzEEUEfmZ3ue2tyEuMZnLly/z6quvMnv2bJ577jm+++47/HOZ8l/vOICxlvDDD0YxU0ICvPMOfPSRU6ob7UGSg3B7+Z3emzV+BSjtnUyLFi3YvHkzYWFhDB48GKVUroknt+OEBAcYOyXfeMOoamzUCKZMgdq17fzOHUtOK4Tbs3l6n8Vsz4M6e4yj0/uya9cuFi5cyAcffHAlMQxavIfYxGQ0VxNPeFSs6XFu0anMOLDAKHvet88oYtqwwe0SA8jMQbix7L/oZn+9Ifdpv2Xj17jEZPzjojiycBQ3lwpm/fr11KtX78rY6yWejQNbXD1OwiW6xP3J0JXfEHgyzih7Hj0aSpe2x1stFJIchFuyPpUwExzok+tzHeqWp/19IYwfP573xw7j/vvvJyIigpCQkBzj8lpX6FC3PB2CUqB3b1i61JghLFpQ6NWN9iCnFcIt2XIH68vXeT4lJYXu3bvz3nvv8fTTT7Nu3bprEgPk3nshJDjAKHsOC3O5smd7keQg3JIt3ZSS0zJNG7qeOXOGVq1a8f333zNs2DDmzp1LYGCg6TFy68kwNviUMUsYOhTatYODB42rEcWKzmRckoNwS7Z2U7JelNy/fz8NGjRg+/btzJ07l+HDh+PllfuvgXVPhtpel/h9+xSa9OoKGRkuWfZsL0UnzQmPMqB11TzXHCDnDGPZsmV06dKFwMBA1q5dS4MGDa77tZaXMG8v6cuSzJ3U/nqCcTrx0Ufw/vsuWfZsL5IchFuyvuIAxu3drQUH+qC15rPPPqN///7Url2biIgIbr/99use33LB8764Q4St+Iqa//2P/xo1
49aZ37l8daM9SHIQbsuyy9J9w1eQaHLL+8z0dHr27MnUqVPp0KEDs2bNokSJEnkee1zkIXzPJzJ07Q90+SuSUyVK0av9QHY/2JKNHpAYQJKDKCLOmSSGjOQL/D33E3Yf28OgQYMICwu77vrCFVrTaP0vDFoznaDLF5n2QHsmNXmWJL9A1LnLDojeNUlyEEWCdSlzWnwMpxYNJ+P8GfqMmMRan1rcPXhZ3tuq9+6FXr0Yv2EDO0OqMaT1Gxwoe1eO1/EUcrVCFAmWlxyTj+zixKz+6JRLvBA2jcjUqqblzzlcvAjvvWeUPe/fT9SwcXR7eUKOxJDvtvJuTpKDKBKyLzl6HVjBqfnDCAguy1fzI/mH8tffd5Hd7bl6dRg3Dl54AQ4dou7wd/mkU50baytfRMhphSgS0tPTWTV9NP9GfEnbtm2ZM2cOJUuWZNT630zHxyUmG92eLcue583LUd2Y77byRYwkB+H2EhMTeeaZZ/j999959913GT16NN7eximG2bZq3/Q03t0dATXmGhWNEybA228XqepGe5DvhnBrhw8fpm3btkRHRzNt2jReeeWVHM9bF0s1PrKLkb9/zZ1nY41uzxMnQoUKhRG6y5PkINzWH3/8QceOHfHy8mLlypU0bdr0mjHZpwXTF2yie/hk2h9Yy8UKd8DsZdCmjbNDdiuyICnc0rfffkurVq0oV64cW7duNU0MAKSn02HDYiImv0r76C0wbBgl/j4gicEGNiUHpVQbpdQhpdRhpdTAXMY8o5Tar5Tap5SaY98whTBkZGTQr18/evTowaOPPsrmzZu5++67zQdv3QoNGhjrCQ0bGnesHj4cAjynVqFAtNbX/Qd4A/8D7gJ8gb+A6lZjKgNRQKmsz8vmddx69eppIfLj3Llz+oknntCAfvvtt3VaWpr5wPh4rXv00FoprUNCtJ4/X+vMTOcG68KAHTqP30+ttU1rDg2Aw1rraACl1DygPbDfYsxrwGStdUJWwjlV4KwlhIXo6GhCQ0M5dOgQU6ZM4fXXX792kNYwc6bR7fnsWejb19g9edNNTo+3KLAlOZQHjlt8HgM8aDWmCoBSaiPGTOMjrfVy6wMppXoAPQAqVqx4I/EKD7RhwwaeeuopMjIyWL58OY8++ui1g6y7PX/9NZjca0LYzpY1B2XymPXu2GIYpxbNga7Ad0qp4Gu+SOupWuv6Wuv6ZcqUyW+swgPNmDGDFi1acPPNN7Nly5ZrE4Nl2bNlt2dJDAVmS3KIASw3v1cA4kzGLNFap2mt/wUOYSQLIW5IRkYG77//Pi+//DJNmzZly5YtVKlS5eqAXMqe6d4dbNl5KfJky3dxO1BZKXWnUsoX6AJEWI0JBx4BUEqVxjjNiLZnoMJzXLx4kY4dOzJ27Fh69uzJsmXLKFWq1NUB0dHQti107AjBwcZMYdo0t24D74ryTA5a63TgLSASOADM11rvU0qNUEpl33I4EohXSu0H1gADtNbxjgpaFF3Hjh3joYce4tdff+WLL77g66+/xscnq8V8SgqMHHm12/OECfDnn9CkSeEGXVTZcknDEf/kUqawtnnzZl22bFkdFBSkly9fnvPJlSu1rlpVa9C6c2etjx8vnCCLAGy8lCknZ8IlzJ49m+bNm1OyZEk2b95M69atjSdOnIBnn4WWLSE9HZYtgwULZD+EE0hyEIUqMzOTIUOG8Pzzz9OwYUO2bt3KvffeaySCzz+HatVg8WL48EPjcqWUPTuNbLwSDpfbXaqTkpJ48cUXWbRoEd27d+err77C19cXtm2Dnj0hKgoeewy+/BIqy8UvZ5PkIBzK+p6W2W3a4k+d4KvBr7Nr1y4mTpxI3759UYmJxj6IqVPhttuMm8V07gzKrNRGOJokB+FQZve0TDx2gF5Pj8SfVCIiInjyiSeulj3Hx0OfPsYGKSl7LlSSHIRDWd/TMunAeuKXfopXYDCbNm6iplLQrNnVsucVK+C++wopWmFJFiSFQ2W3ctdak7hx
LmcixuB76z00fn0CNWfONBKBZdmzJAaXIclBONSA1lXxI50zv4zj3IbZFK/xCN0bhLJ81vtS9uzi5LRCONSD5bzQv33EpYN/UadBJyafi6HJz6OhVi1YOF+qG12YJAfhMFFRUYSGhnIxPp59XbtSffHiq92ee/eG7LJo4ZIkOQi7saxn8IvdyZGFo+lQojgzbr0V/zlzpNuzm5HkIOwiu57hUmo657cuxH/tD/wQUJL/O3MGgoKMsmepbnQrsgIk7GJc5CEuJV8m4dcJPL/2Bw56edMhNZnpLboZjV0lMbgdmTkIuzgee4KKPw3h59NHuB9YV7EWw1r14ujN5XlFuj27JUkOosD2bdjA0Kmv0j31Mv/5l+CN1m+xtGoTUApvpQiPivXoe066K0kO4sZpTVS/flT47DNe1prpNZozqdUbXPQLvDIkQ2sGLd4DIAnCzciag7gheu9ejt9zD3UnTSIuMJCzkZGUnvUjyf7Frxmb45b3wm1IchD5c/EiGe++S0bt2gRGRzPl/vupfPIkZR97jA51y5OprRuTG6z3WAjXJ8lB2EZrCA8no1o1vCdMYKbWfNu/Pz22byewRIkrw7L3UljL7XHhuiQ5CMCoU2gyejV3DvyNJqNXEx4Ve/XJf/+Fdu3gqaf45/RpHvHxwX/2bAaOH4+X1X6IAa2rEuDjneOxAB9vBrSu6oy3IexIFiRFrg1ZvFJTCF05F8LCSFeKYX5+zAwKYlFEBA8+aH3TM0P2oqNZ5yfhXiQ5CNOGLHX/2Umddj3g9HH+ue8+Hv3rL26uXZtNERF53sqwQ93ykgyKADmtEDkWC8tcPMtnEeOY89MQSE3li8cfp8quXdwfGsqGDRvkHqceRGYOgpDgAE6cvUi3qKX0XzcLv4xUPm3QkTGn/8d/y5YxcOBARo4cec36gijaJDkIRlZIpsxn/alx8jDrKtVl8AMd2LZqKvr8KWbMmMGLL75Y2CGKQiDJwZMlJMDgwTT/5huSbynL0K5D+TbDlzMRYyju78tva1bz0EMPFXaUopDIPNETaW10e65a1WgD36cPAf/7m/JNQziz6CPuvacSu/7cIYnBw8nMwdPs2wdvvGHciLZhQ1ixgvSaNenXrx9ffPEFTz75JHPmzOEmaQvv8SQ5eIqkJBgxgsyJE7noE8Co1m+xvml7ep33ZeaTT7JixQr69evH2LFj8fb2zvt4osiT5FDUaQ0REcadpI4d4+c6jzGy6YucDQwi7egRXpn4ChnnTvDtt9/y6quvFna0woVIcijK/v3XSAq//gq1atHz9c9YHnw3AJeP7+X0z5+A1tz53CeSGMQ1ZEGyKEpJgZEjoXp1WLMGxo+HnTuJzEoMF/5awX/zhuAVcBPlXphAernqOfdSCIEkh6Jn1SqoUweGDIEnn4SDB6F/f/Dx4babfElYPY2zyz/Hv2Itbus2Hp9SIQDSb0FcQ04riooTJ4wkMHcu3HUXLF0Kjz9+5enz58+TumwM57evomS9dpRq8SrK6+rCo/RbENZk5uDuMjLgiy+gWjVYtAiGDoW9e68khvCoWOq9P4fSd9dm56Y1lG3zJje3fD1HYgAIDpQbzIicZObgzrZtg1694M8/oVUrmDwZKle+8nR4VCx9PptHzIKPISOdsk+PIKCS+Y1qc2ngJDyYzBzcUUKCkRQaNjROJ376CSIjcyQGgAGjvuDYjwPx8itOuW4Tck0MAOeS0xwdtXAzMnNwcZa3mAsJ8ufz9L3U+2IkOj6eBY2e4uMGXbgp+hYG7Iq70kMhMzOTDz74gMMLxuB/R21Ktx+Ed0DJ676OtHET1mxKDkqpNsBngDfwndZ6dC7jOgMLgAe01jvsFqWHsuzQVPn0UT6e8zX1ju/leJVavN3+Q6JuqQTAhazOTQAtKwfRrVs3wsPDubVBW/yavoryzvljVoDlWYS0cRNm8kwOSilvYDLQCogBtiulIrTW+63GlQTeBrY6IlBPNC7yECQlMXDTXLpvD+ei
byDvt+nNojqPkY7KMTY5LYOwn9YzPHIMu3fvZtKkSVR8uBODf96bo8tTgI83neqVZ83B09LGTVyXLTOHBsBhrXU0gFJqHtAe2G817mNgLPCuXSP0VFpTY9tqPlw5lfIXTvNTrVaMbv4SCYFBpsNT4g6xa3EYxb0z+O2332iTdW9KpZT0cxQ3xJbkUB44bvF5DJCju6hSqi5wu9b6V6WUJIeCyip7nvrrrxwsfQedQ8ewo0KNK097K0WGxeWFpP1rObN0En43lWbzhpVUr179ynPSz1HcKFuSgzJ57Mr/mUopL+BT4KU8D6RUD6AHIL0IzaSkwIQJEBYGXl7s7TuErgEPciHz6o8g+7Rg0c5YLqWmcW7DXM5tmkvA7TWZMmNOjsQgREHYcikzBrjd4vMKQJzF5yWBmsAfSqkjQEMgQilV3/pAWuupWuv6Wuv6ZcqUufGoi6LVq42y5w8+MAqYDhyg5qcf8/HTdSkfHIACygcHMKpjLcI61OKjx+8hadkEzm2aS9n6bfhhYQQvtKhV2O9CFCVa6+v+w5hdRAN3Ar7AX0CN64z/A6if13Hr1aunhdb6xAmtn31Wa9D6rru0Xro0zy+JiYnR9erV00opPX78eJ2ZmemEQEVRAezQefx+aq3zPq3QWqcrpd4CIjEuZU7XWu9TSo3IepEIRyStIi8jA77+2pgpXL4Mw4bBwIGEHzzLuNGrc11A3LlzJ6GhoZw/f54lS5bQrl27QnwToiizqc5Ba70UWGr12LBcxjYveFhFnGXZc8uWRtlzlSq53nkKjIXFhQsX8sILL1CmTBk2btxI7dq1C/NdiCJOyqedybrsed48WLECqlQBzO88lZyWwdjlBwkLC+Ppp5+mbt26bN++XRKDcDgpn3YGreHHH+Hdd+HMGaM704gRYNXE1WzbtE5P5a8fx7F5/1rK1G1FnwlfU7ZsWWdFLjyYJAdHs+72HBkJ95lvgAoJDiDWIkFkXEzg1M9hpMYdIrjpCwQ0fJoPf/sHXz9/qV0QDienFY6SlATvv28kgj17jPtDbNxomhjCo2JpMnp1jsSQeiqaEzP7kXb6CGWeGkxQo2dQSpGcliFdm4RTyMzB3rSGJUugTx84dgxefhnGjIFc6jqsFyEBLv29mTO/jsfLvyTlnhuL76135/ga6doknEGSgz1ZdnuuWRPWr4c87hpluQipteb81kUkrv0B39sqU6bjEIqVuPmar5Ht1cIZJDnYg1XZM+PGGTMHn7xbr2XPAnR6GvGRX5C0dzWB1R7mlif64uXjd8142V4tnEWSQ0GtXm0sOB46BB07wqRJcPvteX9dlpDgAI7FnuD0z5+QErufoCbPEtSkK0oZ+ymCA3xQChIvpcmuSuFUkhxu1MmTRrfnOXNMuz3b6pm7Nf1H9yczKYHSoe9R/N6mV55TwK4PH7Nj0ELYTpJDfuVS9kxA/tcBfvvtN4a80gWlfLn12dH43VYlx/OytiAKk1zKzI/t26FBA+jd2/jvnj0wfHi+E4PWmk8//ZTQ0FAqV67M1/OXE1zx3hxjZG1BFDaZOdgiIcGYKUyZAuXKwdy5hFd5iHGL/yYu8Z98rQWkpqby5ptv8t1339GxY0dmzpxJ8eLFKX1rrHRsEi5FksP1aA2zZhllz/HxV8qew/934bobpHITHx9Pp06dWLt2LR988AEjRozAy8vrytdJMhCuRJJDbvbvN65CrF1rlD2vWHGlunFc5A7TDVLjIg/l+gt+4MAB2rVrR0xMDD/++CPPPfecw9+CEAUhaw7WkpKMBcY6dWD3btOy59wqFHN7fMWKFTRq1IgLFy6wZs0aSQzCLUhysLRkiXHb+jFjoFs3o3bhtdeMwiYLuV1FsH5ca82XX37J448/TsWKFdm2bRuNGjVyWPhC2JMkB4AjRyA0FDp0MLZRr1sH06fnuh9iQOuqBPjkvBGt9dWFtLQ03nzzTXr37s2TTz7Jxo0bueOOOxz5LoSwK89ec0hNhfHj8132
nL2uYHl14ZFqZRgXeYh3ftpFWb90Li8fz+5tGxgwYACjRo3C29v7uscUwtV4bnJYs8ZYcDx48IbKni2vLljurEw7G8ufi0aQnvgfvT+ayNgP33HUOxDCoTzvtOLkSXj+eWjRwpg5LF0KixblKzFYy95Zefnobk7O6k9m8gVu7RLGTr86dgxcCOfynJmDddnz0KEwaNANlT1bi0tM5sKuZZz9fQo+pcpTpvMwfILLSd8F4dY8Izls3240dt25M0e3Z3tIT08nZf00zm76Gf8761Gm/Xt4+RUHZG+EcG9F+7QiIcFYV3jwQYiLg7lzc3R7Lqhz587Rtm1bTmz6mVINOlC287AriUH2Rgh3VzSTQ3bZc7Vq8M03xkapgwehSxdQZrf+zL/o6GgaN27MqlWrmDp1KtOnfEmFm0vkuG2dlEMLd1b0Tissy54ffBCWL4e6de36EuvWraNjx45orfn9999p3rw5cP19FUK4m6Izc7Aue/7mG9i0ye6JYfr06bRs2ZLSpUuzdevWK4lBiKKmaMwcliwxdkza0O3ZFuFR126fble7HAMHDmT8+PG0atWK+fPnExwcbMc3IYRrce+ZQz7Lnm2RXdAUm5iMxtiO/d7crTRq8Tjjx4/nzTffZOnSpZIYRJHnnjOHGyx7toX1/SrTz50ibtEI0uKP8dr7YUQFN6LykEhpyCKKPPdLDgUse86LZeHS5ZgDnP55JDojjbKdh7Pepx7JWc/b2uBFCHflPqcVlmXPKSnw228FLns2k124dHHfGv6bNwgvvwBu6zaeEnfdn2uDFyGKIveYOaSkQL16xh2q7Vj2bC08Kpaky6kkrJvJ+c3z8atYmzIdBlHipuBrEkM2KZEWRZV7JAc/P/jsM6hd2y7VjWZXIwDen7edmPBxXPp7EyXqtObmVr0oVSKAj0JrMC7yUI4b3WaTEmlRVLlHcgDo3Nkuh7G+cW322oHXpTP8++OHpJ0+QqkWr1GyfihKKYr7FbuypmB9w1spkRZFmfusOdiJ9dUIgMSj+zn0zdukJ56gbKeh3PRA+yu3o4tNTKbJ6NUAjOpYi/LBAVIiLTyC+8wc7MR6jSDpwDril07Cu3gpyv1fGL5lrm3llj27GNWxFhsHtnBWqEIUKo/hyjByAAAJZklEQVSZOYRHxdJk9Gp01udaaxI3zOZMxFh8y91DpVcmERRyV65fL1cmhKfxiJmD9TpDZloK8UsncengeorXbMktrd8EP1861SvPmoOnTRceQa5MCM/iETMHy3WG9Avx/Dd3IJcObiC4+Uvc8kQfVDEf0jI0aw6eZuPAFpS3sfW8EEWZTclBKdVGKXVIKXVYKTXQ5Pl+Sqn9SqndSqlVSimX6sGe/Rc/5eRhTs7sR9qZ45TpOISgBztfWXi0HGdL63khiro8TyuUUt7AZKAVEANsV0pFaK33WwyLAuprrS8ppXoBY4H/c0TAZszqFiyvIoQEB/DP1lWc+XUCXoE3Ue75sfiWvXZ9IXtmYNZ6XvZRCE9jy5pDA+Cw1joaQCk1D2gPXEkOWus1FuO3AM/bM8jrya1uAYxfcq01lY5Hsil8LL4hVSnbcQjexUtdcxzrmYHc2FZ4OltOK8oDxy0+j8l6LDfdgWUFCSo/zOoWsq8sXL58mW7dujHnq7E0ffwp7u85iWLFS1E+OIDnG1aUmgUhrsOWmYNZ00Vt8hhKqeeB+kCzXJ7vAfQAqFixoo0hXl9uVxCOx8bRokULNm/eTFhYGIMHD86xviCEuD5bkkMMYLn1sQIQZz1IKdUS+ABoprVOMTuQ1noqMBWgfv36pgkmv0KCA6659Jh66l/if/6YMykXWLhwIZ06dbLHSwnhUWw5rdgOVFZK3amU8gW6ABGWA5RSdYFvgFCt9Sn7h5k76ysLl/7ZyskfB1DS14v169dLYhDiBuU5c9Bapyul3gIiAW9gutZ6n1JqBLBDax0BjANKAAuypu7HtNahDow7xxWKoAAf/Iopjv0xj4S1
P3DPvbVZ+/tSQkJCHBmCEEWaTRWSWuulwFKrx4ZZfNzSznFdl/UVioQLlzi38ivO/fU7zzzzDN9//z2BgYHODEmIIscty6ctr1BkXDrH6Z8/ISVmHxVadGPu3Bl4eXlE4acQDuWWySH7CkXq6aOcXjSCjKQESrcbQLHqzSQxCGEnbpkcQoIDOLxzPacjxuDl48+tXUfhF1JV9j4IYUdu92dWa021M+s5tWgExYJvo9wLE/ELqSp7H4SwM7eaOaSlpfHWW28xfepUGrZ4HN3sTf67hOx9EMIB3CY5xMfH07lzZ/744w8GDRpEWFiYrC8I4UBukRwuX75M48aNOXLkCDNnzqRbt26FHZIQRZ5bJAd/f3/69OnDfffdR+PGjQs7HCE8glvMy8OjYpl9vhrPRSTQZPRqwqNiCzskIYo8pbVd9j/lW/369fWOHTvyHGddDQng46Uo4V+MxEtpshgpRD4ppXZqrevnNc7lTyvM+jWkZWoSLqUBckNbIRzF5U8rbOn4LG3jhbA/l08OtlY9Stt4IezL5ZODWSdoM1I6LYR9ufyag3Un6KAAH5JS00nLuLqQKqXTQtifyycHuLYTdF6t6IUQBecWycGatI0XwvFcfs1BCFE4JDkIIUxJchBCmJLkIIQwJclBCGFKkoMQwpQkByGEKUkOQghTkhyEEKYkOQghTElyEEKYcsu9FULYQ2Fu4HOHzYOSHIRHsu5N6sx2g4X52vkhpxXCI5n1JnVWu8HCfO38kOQgPFJubQWd0W6wMF87PyQ5CI+UW1tBZ7QbLMzXzg9JDsIjmfUmdVa7wcJ87fyQBUnhkax7kzrzikFhvnZ+uPwdr4QQ9mXrHa/ktEIIYUqSgxDClCQHIYQpm5KDUqqNUuqQUuqwUmqgyfN+Sqmfsp7fqpSqZO9AhRDOlWdyUEp5A5OBx4HqQFelVHWrYd2BBK31PcCnwBh7ByqEcC5bZg4NgMNa62itdSowD2hvNaY98EPWxwuBR5VSyn5hCiGczZbkUB44bvF5TNZjpmO01unAOeAWewQohCgctiQHsxmAdXGELWNQSvVQSu1QSu04ffq0LfEJIQqJLRWSMcDtFp9XAOJyGROjlCoGBAFnrQ+ktZ4KTAVQSp1WSh3NZ7ylgTP5/BpnkvgKRuIrGFvju8OWg9mSHLYDlZVSdwKxQBfgWasxEcCLwGagM7Ba51F6qbUuY0uAlpRSO2yp7CosEl/BSHwFY+/48kwOWut0pdRbQCTgDUzXWu9TSo0AdmitI4BpwCyl1GGMGUMXewUohCgcNm280lovBZZaPTbM4uPLwNP2DU0IUZjcrUJyamEHkAeJr2AkvoKxa3yFtitTCOHa3G3mIIRwEpdMDq6+l8OG+PoppfYrpXYrpVYppWy6dOSs+CzGdVZKaaWU01bgbYlNKfVM1vdvn1JqjrNisyU+pVRFpdQapVRU1s/3CSfHN10pdUoptTeX55VS6vOs+Hcrpe6/4RfTWrvUP4wrIv8D7gJ8gb+A6lZj3gCmZH3cBfjJxeJ7BAjM+riXq8WXNa4ksA7YAtR3ldiAykAUUCrr87Ku9L3DOK/vlfVxdeCIs+LLes2mwP3A3lyefwJYhlGY2BDYeqOv5YozB1ffy5FnfFrrNVrrS1mfbsEoHHMWW75/AB8DY4HLLhbba8BkrXUCgNb6lIvFp4Gbsj4O4tqCQIfSWq/DpMDQQntgpjZsAYKVUrfdyGu5YnJw9b0ctsRnqTtGJneWPONTStUFbtda/+rEuMC2710VoIpSaqNSaotSqo3TorMtvo+A55VSMRiX93s7JzSb5ff/z1y5YoNZu+3lcBCbX1sp9TxQH2jm0IisXtbksSvxKaW8MLbVv+SsgCzY8r0rhnFq0RxjxrVeKVVTa53o4NjAtvi6AjO01hOUUo0wiv9qaq0zHR+eTez2u+GKM4f87OXgens5HMSW+FBKtQQ+AEK11ilOig3yjq8kUBP4Qyl1BOO8NMJJi5K2/myX
aK3TtNb/AocwkoUz2BJfd2A+gNZ6M+CPsafBVdj0/6dNnLmYYuOCSzEgGriTq4tCNazGvEnOBcn5LhZfXYyFrcqu+P2zGv8HzluQtOV71wb4Ievj0hhT5FtcKL5lwEtZH9+b9YunnPwzrkTuC5JPknNBctsNv44z31Q+3vwTwN9Zv2AfZD02AuOvMBjZegFwGNgG3OVi8a0E/gN2Zf2LcKX4rMY6LTnY+L1TwERgP7AH6OJK3zuMKxQbsxLHLuAxJ8c3FzgBpGHMEroDPYGeFt+/yVnx7ynIz1YqJIUQplxxzUEI4QIkOQghTElyEEKYkuQghDAlyUEIYUqSgxDClCQHIYQpSQ5CCFP/D3C0klUltwdDAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 288x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Fit a linear regression using mean squared error.\n",
    "regression = RegressionModel()\n",
    "params = regression.parameters()\n",
    "optimizer = torch.optim.Adam(params, lr=0.01)\n",
    "\n",
    "# Hijacking the general loss to compute MSE: alpha is fixed at 2 and the\n",
    "# scale at 0.1. Neither changes during training, so the tensors are built\n",
    "# once here instead of on every iteration.\n",
    "mse_alpha = torch.Tensor([2.])\n",
    "mse_scale = torch.Tensor([0.1])\n",
    "\n",
    "for epoch in range(2000):\n",
    "    y_i = regression(x)\n",
    "    loss = torch.mean(robust_loss_pytorch.general.lossfun(\n",
    "        y_i - y, alpha=mse_alpha, scale=mse_scale))\n",
    "\n",
    "    optimizer.zero_grad()\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "\n",
    "    # Report progress every 100 epochs.\n",
    "    if epoch % 100 == 0:\n",
    "        # item() converts the scalar loss tensor to a plain Python float.\n",
    "        print('{:<4}: loss={:03f}'.format(epoch, loss.item()))\n",
    "\n",
    "# It doesn't fit well.\n",
    "plot_regression(regression)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0   : loss=1.497499  alpha=1.004995  scale=0.993690\n",
      "100 : loss=0.365647  alpha=1.460623  scale=0.433759\n",
      "200 : loss=-0.307419  alpha=1.678272  scale=0.181545\n",
      "300 : loss=-0.440313  alpha=1.445860  scale=0.125634\n",
      "400 : loss=-0.747422  alpha=0.771877  scale=0.074583\n",
      "500 : loss=-1.213966  alpha=0.286796  scale=0.031820\n",
      "600 : loss=-1.439339  alpha=0.127493  scale=0.017698\n",
      "700 : loss=-1.485178  alpha=0.074832  scale=0.014162\n",
      "800 : loss=-1.498284  alpha=0.052338  scale=0.013124\n",
      "900 : loss=-1.503734  alpha=0.040308  scale=0.012733\n",
      "1000: loss=-1.507125  alpha=0.032830  scale=0.012530\n",
      "1100: loss=-1.509135  alpha=0.027691  scale=0.012406\n",
      "1200: loss=-1.510411  alpha=0.023921  scale=0.012322\n",
      "1300: loss=-1.511533  alpha=0.021025  scale=0.012260\n",
      "1400: loss=-1.512304  alpha=0.018723  scale=0.012211\n",
      "1500: loss=-1.512898  alpha=0.016844  scale=0.012176\n",
      "1600: loss=-1.513401  alpha=0.015280  scale=0.012141\n",
      "1700: loss=-1.513436  alpha=0.013955  scale=0.012122\n",
      "1800: loss=-1.514139  alpha=0.012819  scale=0.012093\n",
      "1900: loss=-1.514423  alpha=0.011832  scale=0.012073\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQcAAAD8CAYAAAB6iWHJAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAIABJREFUeJzt3XlYlFX/x/H3l0UBRTC3EtdKM7LS5HHNNUsUU9Mst8xyadPysfRBLTPLXEjNp8wibc8ttwctpVxKfyQqampYlmmluC9oCcp2fn+ABsMAAw7MDPN9XVfXNTPcc883kA/nPufc54gxBqWUsuTh6AKUUs5Jw0EpZZWGg1LKKg0HpZRVGg5KKas0HJRSVmk4KKWs0nBQSlml4aCUssrLUR9cuXJlU6dOHUd9vFJua8eOHaeNMVUKOs5h4VCnTh3i4uIc9fFKuS0R+cOW4/SyQilllYaDUsoqDQellFUaDkopqzQclFJWaTgopazScFBKWaXhoJSrS0uDF1+EI0fseloNB6Vc2YUL0K0bTJ4MK1bY9dQOmyHpCCt3JRARvZ+jiclUD/RldKdb6NE4yOnOWZxcrV5l3cpdCXyyeBOTPxhHvTN/snf8VBqPGGHXz3CblsPKXQmMXb6XhMRkDJCQmMzY5XtZuSvBqc5ZnFytXmXdyl0JfPb2Mt59+xmCzp/ksQcn0v0P4Ytth+z6OW4TDhHR+0lOTc/xWnJqOhHR+53qnMXJ1epV1sXOnM+nn4whxdObngOms/rEQQ5/Fs4LL71m189xm8uKo4nJhXrdUecsTq5Wr7JgDMycyeufTWT3DfUZ0j2c/f/3GRd/XI9fg9bQMMyuH+c24VA90JcEK78E1QN9neqcxcnV6lX/9BGdPPMXEZvep8e2L9l4e1uevHsQh6MiuJywj4BW/Qho1ZcaFf3s+tluc1kxutMt+Hp75njN19uT0Z1ucapzFidXq9fdXekjunDiNPOXvkKPbV8S2fIhljw2nD8WhJNy4gCVu40h8O5++JXxsvvP0W1aDld65O3ZU18c5yxOrlavu4uI3k+l00eZv/QVbjybwOjOz/JxuYqcCR9IQAV/agybxV/+tYvt5yiO2iszJCTE6GIvSuXtgYEziVz2KmXSU3myx1iiT/3OuY0fUKZqXeZ8tJAhof8q0nlFZIcxJqSg49zmskIpl7J0KYsWjiWpjA89+01h9U+bOLdhHn71mlOt3zRmxJwp9iFoDQelnIkxMG0a9O7N37fdwUN9XiVmXSR/7/maCi0eonKPcDzK+JTIELTb9Dko5fRSUuDpp2H+fOjTh9Ph4fzZtTuXjx2lUtfnKX9b+xyHF/cQtE0tBxEJFZH9InJARMKtfL2WiGwUkV0iskdEuti/VKVKsXPnoHPnzGB46SW+HjSIZm3bkpGSzO1DZ+QKBij+IegCWw4i4gnMAe4FjgDbRSTKGLMv22EvAkuMMXNFJBj4CqhTDPUq5VIs72Vp36AKG38+lXO0KOAyhIXBb7/Bxx/z9oULjAwLIzg4mFWrVrHrrBdjl+/NMbu1JIagbWk5NAUOGGMOGmNSgEVAd4tjDFAh63EAcNR+JSrlmqzdy/JZ7J85ni+evZjLIU3hxAnS1qzhma1bGTFiBF26dCEmJobatWvTo3EQU3reTlCgLwIEBfoypeftxT4EbUufQxBwONvzI0Azi2MmAl+LyAigHNDR2olEZBgwDKBWrVqFrVUpl2LtXpbsuv60iRlfzuJkYFUqrFnBg+PGsW7dOsaMGcPrr7+Op+c/E9Z6NA4q8fkotoSDWHnNcnJEX+AjY8wMEWkBfCoiDY0xGTneZEwkEAmZ8xyKUrBSriLPDkNjeGbLEkZv/pRtNYIZ3H4wKQMGcOjQIT788EMGDRpUonXmxZZwOALUzPa8BrkvGwYDoQDGmC0i4gNUBk7ao0ilXFGgnzfnklJzvOadnsqUtW/z4I/r
WRHcjmdva8/xpRMJ8CvD+vXrad26tYOqzc2WPoftQD0RqSsiZYA+QJTFMX8C9wCIyK2AD3DKnoUq5WosJx8HJP/Fp4tf4sEf1zOrVT+G1AjmyLJJ3HDDDWzbts2pggFsaDkYY9JEZDgQDXgCHxhj4kVkEhBnjIkCngfeF5F/k3nJMcg4al62Uk7ifPI/rYba547y4RcTCbpwkufC/s2Hx3/jr5gFNG7Zjo1frSQgIMCBlVpn0yQoY8xXZA5PZn9tQrbH+4BW9i1NKdd25Rb5kCPxRC6fDEC/ni+xYfdq/vplO//+97+JiIjI0fHoTHT6tFLFZHSnW3hw/yY+XzSec77+3N/9P6zeOJ+/Du4iMjKSmTNnOm0wgE6fVqp4GEOPqHn0WDmdnTfeSf/G3TkQNQ0fLyHq669p3z73jEdno+GglL1dvgyDB8Pnn8Ojj7K3VSt+e+YZbrrxRlatWkW9evUcXaFN9LJCKXs6fRo6doTPPyfj1VcZXbkyg4YNo23btmzZssVlggE0HJSyn19+gebNYft2kj/8kO5bt/LGjBlc37wbvzYeQdf3drnUNgB6WaGUPXz3HTzwAHh6cnzBAu6bOJH4ffuo1ukpyjbKXBX6yj4hgEsszactB6Wu1ccfw733wvXXs3PuXO586in+/PNPGgx8HZ9GOZeLd6V9QjQclCqqjIzMDWwHDYI2bVj07LO06N8ff39/YmNjSap6m9W3uco+IRoOShXFpUvQrx9Mnox5/HFebNKEvk89RcuWLdm6dSsNGjTIczEWV9knRMNBqcI6dQruuQcWLyZl0iR6nT3L5OnTGTp0KNHR0VSqVAlw/X1CtENSqcL46afMVZuOHeP03LncFxnJ7t27mTlzJiNHjkTknxUOXH2fEA0HpWy1fj306gU+PsS/8w4dx43j4sWLrFq1ii5drC+b6ohFWuxFLyuUssW8eRAaCjVqsOrFFwl5+ml8fHz4/vvv8wwGV6fhoFR+MjIgPByGDsW0b8/rYWF0GzGCkJAQtm3bRsOGDR1dYbHRywql8pKcDAMHwtKlpA4ezKPnz7Nw+nQGDRrEu+++S9myZR1dYbHScFDKmhMnoFs32L6d8y+/zL1ffkncjh1MmzaN0aNH5+h4LK00HJRbstxPIscoQnx85ojEqVP89sYbtJ05k8TERFauXEm3bt1sP4+L03BQbufKfhJXlo3Pcc/DqXjo3RvKlWPDxInc/9JLVK5cmZiYGO68807bz1MKAkI7JJXbsbafRHJqOr9MegO6dMHUqcNbAwZwz5gx3HnnnWzbti1XMOR3Hle5d6Ig2nJQpUJhmveW9zaIySD82494Ytty0jt14omAAOZHRNC/f3/mzZuHj4+PTecp6HVXo+GgXF5hm/dXFn4F8Em9xJurZxD6yxYWNwnl7fOJ/F90NK+99hrjxo1DRPIMnuznsTx/aaCXFcrlFbZ5f+Wehyp/n2XJgnDu/XUrr7TozYiEvezYvZulS5cyfvz4q8Fgud/l2OV7WbkrweXvnSiIthyUy7ryF93aX2/Iu3nfo3EQFX7ZR/ATL+CfdIFn2vXj49jlXFcxkM2bN9OkSZOrx+YXPDHhHa4eo6MVSjkJy0sJawL9vK1/Yc0aOgx5CBMQwCePPUnk7NncddddREVFUb169RyHFtSv4Mr3ThRELyuUSypoB2uAS9a+PmcOdO1Kxk038XyrVgx680169+7Npk2bcgUD5N1/UFr6FfKj4aBcki0jAsmpGf8s6JqeDiNHwvDhXL73XjqXK8esJUuYMGECCxcuxM/Pz+o5Snu/Qn70skK5pLxGCixFRO+nR72AzFWbVq3izCOP0GzzZhKOH2fhwoX06dMn3/e7+poM10LDQbmk0Z1uKbDPASD98GFo8wLs3k38M8/Q8tNP8fPz47vvvqNp06b5vtdyCHPWw43cIhSu0HBQLsnyLzpkbu+eXfCJg3y4fBIm4xL/GzKEXnPncscddxAVFUXNmjXzPX9pnxpt
C+1zUC6rR+MgYsI7cGhqGAG+OUcm7jmwlS8+H0MGwqudOvFAZCTdunVj8+bNBQYDlP6p0bbQcFClwvnk1MwHxvBY3P94f9lr/FbxBlr6V+LlZcsYO3Ysy5Yto3z58jadr7RPjbaFhoMqFaoH+uKZkc4r697l5fXvE137DlqnJPHH8d94btKbfFehIzeNW0OrqRts2pLOnYcwr9BwUKXC2FbV+XD5qzy680vm3HI3XY/9ysWUSwx8bT7RKbdYnf6cH3cewrxCw0G5vsOH6Tr8Ye7+fRfj7ujA8F++x6diNd5ZEs2vBBWp76BH4yCm9LydoEBfBAgK9GVKz9vdpjMSdLRCubq4OLj/fkxSEu+EhTElKoquXbuyYMEC/P39mbL5S6tvs6XvoDRPjbaFthyU61q5Etq0IcPbm2G33caIqCheeOEFVq5cib+/P6B9B9dCw0G5HmNgxgzo2ZNL9erRukwZPo6LY/78+URERODp+U9fgfYdFJ1eVijXkpoKI0bAe+9xsm1bGu3eTYqnJ+vWraNNmza5Dnfn6c/XSsNBuY7z5+Ghh+Drr/khNJSm33zDzfXrs2rVKm666aY83+bufQdFZdNlhYiEish+ETkgIuF5HPOQiOwTkXgRWWDfMpXb+/13aNUKs2EDi+69l8Zr19KhY0e2bNmSbzCooiswHETEE5gDdAaCgb4iEmxxTD1gLNDKGHMbMLIYalXuautWaNYMc+QI4++6i77ffMOzzz7L6tWrCQgIcHR1pZYtLYemwAFjzEFjTAqwCOhuccxQYI4x5hyAMeakfctUbmvpUmjXjtSyZelWpQoRO3fy7rvvMnv2bLy89Kq4ONkSDkHA4WzPj2S9ll19oL6IxIhIrIiEWjuRiAwTkTgRiTt16lTRKlbuwRiYNg169+bCzTdz+8WLxJw5w9q1a3niiSccXZ1bsCUcrG0KaHl3rBdQD2gH9AXmiUhgrjcZE2mMCTHGhFSpUqWwtSp3kZoKQ4dCeDgHmzUj6OefMZUrExsbyz333OPo6tyGLeFwBMh+j2sN4KiVY/5njEk1xhwC9pMZFkoVzrlzEBoK8+ezvkULbt66lWZt2xIbG0v9+vUdXZ1bsSUctgP1RKSuiJQB+gBRFsesBNoDiEhlMi8zDtqzUOUGDh6Eli0xmzfzZqNGdNyyhSeefJI1a9ZQsWJFR1fndgoMB2NMGjAciAZ+ApYYY+JFZJKIXNlyOBo4IyL7gI3AaGPMmeIqWpVC338PzZuTfuwYQ2rV4vk9e3jrrbeYO3cu3t55LDGvipUYY9l9UDJCQkJMXFycQz5bOZnFi+HRR0muUoUOSUn8lJ7O4sWL6dSpk6MrK5VEZIcxJqSg4/TeCuU4xsBrr0GfPpysU4ebT57kVMWKbNmyRYPBCWg4KMdISYHHHoOXXuKHhg2puX8/9Vq0YOvWrdx6662Ork6h91aoEmC5xPu45lUJmzgcvvuORcHB9P3xRwYPHsw777xDmTJlHF2uyqLhoIqV5RLvXod+47aIgaT/dZKX69Rhys8/M3PmTEaOHImItSk1ylE0HFSxyr7Ee8iReCKXT4aMdO7x8mXnmTNERUURFhbm4CqVNRoOqlhd3Y06fiPT1szmsE8FOqckc6hcRX6IWU/Dhg0dXKHKi4aDKlbVA3zo/eV8RsYsZHNANbqfP0FSUDCNB03SYHByGg6q+Fy+zKLNc6gZs4LPAqrx+PkTlGnYgdpdn2N8ryaOrk4VQMNBFY/Tp6FnT2pu3szMqjfw/MljVGw7iAadBjAmtIGuzOQCNByU/e3fD2FhZBw+zPDrruPjv8+zYsUKevTo4ejKVCFoOCi7WbkrgW/mLmHypxNIN4aewEE/P2LWr6dRo0aOLk8VkoaDsouVuxKInTiLWatnc9CnHF2SznMs6BbmzF+gweCidPq0unYZGSQ+H87UqBnE+PrTPOk8J25tS5WHX2f+zkRHV6eKSFsO6tpcugSPPcagjYv40C+AYRfPUe7u
/lRu2QcRcast60sbbTmoojt5Ejp0gEWLGOdTjsGXkwjs9h8CW/W9OhXaQ8SmLe+V89GWgyqan36CsDDSExIY6OPDWv/y1Lp/HFTJuYdEujGMXb4XQIcvXYy2HFThrV+PadGCi6dP0zIlhf233caeXTt4c3gvPK3cPGXLlvfK+Wg4qMKZPx8TGspREYL/+otaDz7Ipk2bCArK3HIuI4+VxbTvwfVoOCjbZGRAeDgMGcJ2f3+CExN59KWXWLx4MX5+flcP0y3vSw/tc1BA7gVZcuxEnZQEAwfCsmUs8Pdn6MWLvP/55/Tr1y/XeUZ3uiXH+g2gW967Kg0HlWtBloTE5H86EW/whO7dMdu3M75sWT7w9WXDN9/QrFkzq+fSLe9LDw0HlWNBliuSU9NZ+kk03VdMIu3YMR4GDjZowLaoKGrVqpXv+XTL+9JBw0FZ7SxsfWgnc1ZO5bwX3JOSQs3u3fm/zz6jfPnyDqhQOYJ2SKpcnYX9fljDh19M5E8yuD0pifvCw1m+fLkGg5vRcFCM7nQLvt6eeGSkM27DfF6PnsM3Xt60Tk/jtY8+YsqUKXh46D8Vd6OXFYoejYPwTE6iwtDHaLsvhrc9vZng58eq1VHcfffdji5POYiGg4Jjx7j/2b5k/LSTkSKsv/UWdkRFUbduXUdXphxIw8Hd7dmD6dqVlOPH6WUMhIURs2ABFSpUcHRlysE0HNxM9slOvY7vYcqS10jMSOe+1FQ6jBrF9OnT8fT0dHSZygloOLiR7JOdHtm5monr3mOPhxfdyWDC++8zZMgQR5eonIiGgxuJiN7P5cspTNgwn8d3RLHKw4v+XmWp2meCBoPKRcen3EjiybO8t2Iyj++IYhbCgwHV8H90JmnXB+uCLCoXDQd3kZDA8kVj6XBgG08D4+o0otrAGXhXrA6g6y2oXPSywh3s2kVGWBg1Tp6kKxDT5H6qdhiCePzT8ajrLShL2nIo5WLf/IikZi04cuIkLTIy2Bn6DNd1fCJHMAAE+nk7qELlrLTlUFoZw54xk2j6xivs8PCgh1cZ0h4Yj28d63tI5LGAk3JjGg6lUVoajBzJHXPmsBxhUIUq+D84Ed9KNfJ8y/nk1BIsULkCDQcnl9cKTXmu3HThAubhh5G1a4kAXq7VkEo9xuHp65/v5+gybsqSTeEgIqHAbMATmGeMmZrHcQ8CXwD/MsbE2a1KN5XXCk1xf5xl2Y6EXK/7Hj9KxzGPQ3w8TwFRTbtSpc0QxDPnj1mA7FcRuoybsqbADkkR8QTmAJ2BYKCviARbOc4feBbYau8i3VVeKzQt3Ho41+s3Hd7PXQ/dR9K+fXQBbnvzTebOnYufT9kcx/l6e9K/eS2CAn0RICjQlyk9b9eVm1QutrQcmgIHjDEHAURkEdAd2Gdx3KvAdOAFu1boxvIaXky36D2875ctzI6azomMdLr6+fLK0qWEhoYCICK6nqMqElvCIQg4nO35ESDH6qIi0hioaYxZLSIaDnZSPdCXBCsB4SmSGRDGMGT7CsZt/IDtCL0DqrD2+40EB//TsNP1HFVR2TLPIfcWRtkuWUXEA5gFPF/giUSGiUiciMSdOnXK9ird1JUVmrLz9fakb7Oa+HsYXouew4sbP2AZ0DnoVl5bvi5HMCh1LWxpORwBamZ7XgM4mu25P9AQ+DZr89TrgSgR6WbZKWmMiQQiAUJCQnRkvQB5LvN+Y3mGv/4U1++OYQowu0kn3ntnLr2b6uIsyn5sCYftQD0RqQskAH2Aq7uZGGPOA5WvPBeRb4EXdLTCPnJdFvzxB6nNmlF5/34GA8FvvMGxUaOu7mqtlL0UGA7GmDQRGQ5EkzmU+YExJl5EJgFxxpio4i7SneS789S2baR27kzSuXP09/HhiSVLuP/++x1bsCq1bJrnYIz5CvjK4rUJeRzb7trLck/57jx1MJb0vn05kpbG4Ouv5821a7njjjscWa4q5XSGpBOxOq8hJY2jYydC9Dy2ApP/9S8WrV5N
1apVHVKjch8aDk7Ecl6DV3oar619iz4/rmchMObODsz4b6QGgyoRGg5OJPu8hgqX/mbOsldpfSSeV4FZrR/Bv8VDvPzlr5Qp66NzF1Sx0/UcnMDKXQm0mrrhajDUTDzOso+eo9mReB719OKtB8ZRoeXDiAjJqem6apMqEdpycDDLTsi7En7ivSUv45mSRKhvAD8/PAm/ajfleI+u2qRKgoaDg2XvhOy67zveWD2DIyaDB6rU4dxDkyhT/rpc79Hbq1VJ0HBwsKOJyWAMz8QsZHTMAjYD/W9uDt1G4+VdNtfxenu1KikaDg5Wq7wXwz+fTO+fN/MZMKrFw/i2HnB1xmOgrzcikJiUqndVqhKl4eBIZ8/y+cIx1Ph5DxM9PJkTNopywW2vflmAH16+z3H1Kbem4eAoBw7wd7t2VElIYKBPOb556FXK3VA/xyHat6AcSYcyHcBs3kxyo0ZcSkjg6Xr1aLNiI4G1bs1xjPYtKEfTlkMR5XuDVD7SPvkE89hj/JmRwX/vu4//Ll9OuXLlqFytaOdTqrhoOBRBvjdI5fULbQxJ4eH4TZ/Ot0DMqFG8FRGBh4fH1fdpGChnopcVRZDXwq95zly8fJnz3brhN306n3p4cOyjjxg/Y8bVYFDKGWnLoQjymqFo9fXTpznXrh0V4+OZUq4c7b7+mhYtWxZzhUpdO/3TVQR5jSJYvm5+/pnEW2/FNz6e0TVr0i8+XoNBuQwNhyLIa+HX7KMLaevXk9SoESmnT/NSq1ZMiI+ndu3aJV2qUkWm4VAEPRoHMaXn7Tk2hunVJIiI6P3UDf+SV3o8h7n3Xv64fJkPhg5l6nff4e+f/3Z0Sjkb7XMoouyjC1dGLy6lpDLy67k898Ma1iF885/JTJs6zsGVKlU0Gg52EBG9n4zkZGZ98TI9/tzLfE9vXu49kdqBzR1dmlJFpuFgB5eOHufTT5+naeJxxvoF8smACLwr3qDrLiiXpuFwjdL27mV55FCqXk6ib9Ubiek3Be+y5QC9N0K5Ng2Ha/B3VBT06oVPWhrdgtvwS9jzeHhkjmLovRHK1eloRRGdmjqVst2783taGhtef51nPltAjevK67b2qtTQlkNhZWTw5yOPUGvBAjZ4eeG9YgV9u3YF8rmvQikXpOFQGMnJHGrdmro7drAoIICQ2FhubtDA0VUpVSw0HKywdjv2/dWEhCZNqH38OO/Vr8/DsbEEVqzo6FKVKjba52DhyoSmhMRkDJm3Y0e+9QUnbryJysePM69zZwbHx2swqFJPw8GC5e3YLX/cwMKPRsHlS8wa+BQft3mBei9G02rqBlbuSnBgpUoVL72ssJB94lLvjR8wZdty4sWDQV2f51Kte0jO+rpNC7wo5cK05WCheqAvYjIY/cXLRGxbzjdeZen9yAwuNGxfuAVelHJx2nLIZuWuBNIvXOC/84Zz/9kE3ilfiemDZuMbWIkUi2C4QqdIq9LKLcPB2mgEwBvzopkz/1kaXb7I6OoNWNJvKhXL+zKx221ERO+/utFtdjpFWpVWbhcOeS0O2+CPnSz4fALXZaTT/4772BI6AhGhXFmvq30K2d8HOkValW5u1+dgbXHYu75fxiefjscjI53u9wwjtvOzV7ejS0hMptXUDQC5FnjRKdKqNHO7loNlH8HDX73J63vXscfDi8G9X+FsnTtzvedK62JKz9uJCe9QUqUq5VBu03JYuSuBVlM3YLKeS3oa4Z+MYtredazxKc+gp94nud5deb5fRyaUu3GLcMg+6xHA52Iic+Y+zpPHfuHtSjUZ8dRHJAVUpVeTIILy6WDUkQnlTtwiHLL3M1Q+9isL3x1Mp4tneaFecyIGv4Mp40NqumHjz6eICe+QZ0DoyIRyJzaFg4iEish+ETkgIuFWvj5KRPaJyB4RWS8iTrUG+5W/+PXiv+V/n47i5rTL9G/Zl6U9X7za8Zj9OFuWnleqtCuwQ1JEPIE5wL3AEWC7iEQZY/ZlO2wXEGKMSRKRp4Dp
wMPFUbA1BW1qWz3Qlzr/e5f3Yr/grHjQo/sYfmvQOtd5rrQMrrxXN7ZV7syW0YqmwAFjzEEAEVkEdAeuhoMxZmO242OBAfYsMj8FbWprMjJ4YvVk+sdu4Afvsgwb8AZnqtbNdR7LloFubKvcnS2XFUHA4WzPj2S9lpfBwJprKaow8tvU9tLff/PNrbcyMGYD31e7gRf+s4izVesSFOjLgOa1dM6CUvmwpeUgVl4zVl5DRAYAIUDbPL4+DBgGUKtWLRtLzF9eIwjnDv3Gzpq9uC8xkS13383d337LRk9Pq8cqpXKzpeVwBKiZ7XkN4KjlQSLSERgPdDPGXLZ2ImNMpDEmxBgTUqVKlaLUm4u1EYTKB3eyaN7TNE1MZOcTT9Bi82ZEg0GpQrElHLYD9USkroiUAfoAUdkPEJHGwHtkBsNJ+5eZN8uRhZu3rSTqiwnUMekcfPtt7nr33ZIsR6lSo8DLCmNMmogMB6IBT+ADY0y8iEwC4owxUUAEUB74Imto8E9jTLdirDvHCEWArzdlvYSGX0wjMn4jZ729uRQdTf327YuzBKVKNZvurTDGfAV8ZfHahGyPO9q5rnxZjlCcu3CRRxb+h0nHfuW3664jKC4Ov7q5RySUUrZzyRuvso9QyF9nGP/xSIZePMfqKrXocmgfHuXKObhCpVyfS06fvjJCUTbhZ96LHMrQi+eYfXMznn3sbQ0GpezEJVsO1QN9ubwpio/W/pdbjGFU84dY3nZgvjdNKaUKx+XCwRhD570LGbpmHmXEg/73v8C24LZ674NSduZS4ZCamsr8Ll14ft06zvr48tyQmWwvV5MgvfdBKbtzmXA4c/o0S5s148mDB/kjKIiaO3awqFo1R5elVKnlEuFw6a+/2HDTTTxx4QKHmjen7saN4OPj6LKUKtVcIhx8fH1pWrMmh//1L+rOnw8eLjnIopRLcYnfspV7T/BI31m0qfYgraZ/q3tUKlUCxBirN1gWu5CQEBMXF1fgcZazIQG8PYTyPl4kJqXqQixKFZKI7DDGhBR0nNNfVlhbryE1w3AuKRXQDW2VKi5Of1lhy4rPumy8Uvbn9OFg64rPumy8Uvbl9OFgbSVoa3TZeKXsy+n7HCxXgg7w9eZiShqp6f90pOpcVQYmAAAFMElEQVTUaaXsz+nDAXKvBF3QUvRKqWvnEuFgSZeNV6r4OX2fg1LKMTQclFJWaTgopazScFBKWaXhoJSySsNBKWWVhoNSyioNB6WUVRoOSimrNByUUlZpOCilrHLJeyuUsgdH3sDnCjcPajgot2S5NmlJLjfoyM8uDL2sUG7J2tqkJbXcoCM/uzA0HJRbymtZwZJYbtCRn10YGg7KLeW1rGBJLDfoyM8uDA0H5ZasrU1aUssNOvKzC0M7JJVbslybtCRHDBz52YXh9DteKaXsy9Ydr/SyQilllYaDUsoqDQellFU2hYOIhIrIfhE5ICLhVr5eVkQWZ319q4jUsXehSqmSVWA4iIgnMAfoDAQDfUUk2OKwwcA5Y8zNwCxgmr0LVUqVLFtaDk2BA8aYg8aYFGAR0N3imO7Ax1mPlwL3iIjYr0ylVEmzJRyCgMPZnh/Jes3qMcaYNOA8UMkeBSqlHMOWcLDWArCcHGHLMYjIMBGJE5G4U6dO2VKfUspBbJkheQSome15DeBoHsccEREvIAA4a3kiY0wkEAkgIqdE5I9C1lsZOF3I95Qkre/aaH3Xxtb6attyMlvCYTtQT0TqAglAH6CfxTFRwKPAFuBBYIMpYOqlMaaKLQVmJyJxtszschSt79pofdfG3vUVGA7GmDQRGQ5EA57AB8aYeBGZBMQZY6KA+cCnInKAzBZDH3sVqJRyDJtuvDLGfAV8ZfHahGyPLwG97VuaUsqRXG2GZKSjCyiA1ndttL5rY9f6HHZXplLKublay0EpVUKcMhyc/V4OG+obJSL7RGSPiKwXEZuGjkqqvmzHPSgiRkRKrAfeltpE5KGs71+8iCwoqdps
qU9EaonIRhHZlfXz7VLC9X0gIidF5Mc8vi4i8t+s+veIyF1F/jBjjFP9R+aIyG/AjUAZYDcQbHHM08C7WY/7AIudrL72gF/W46ecrb6s4/yBTUAsEOIstQH1gF1AxaznVZ3pe0fmdf1TWY+Dgd9Lqr6sz2wD3AX8mMfXuwBryJyY2BzYWtTPcsaWg7Pfy1FgfcaYjcaYpKynsWROHCsptnz/AF4FpgOXnKy2ocAcY8w5AGPMSSerzwAVsh4HkHtCYLEyxmzCygTDbLoDn5hMsUCgiNxQlM9yxnBw9ns5bKkvu8FkJnlJKbA+EWkM1DTGrC7BusC27119oL6IxIhIrIiEllh1ttU3ERggIkfIHN4fUTKl2ayw/z7z5IwLzNrtXo5iYvNni8gAIARoW6wVWXysldeu1iciHmTeVj+opArKxpbvnReZlxbtyGxxbRaRhsaYxGKuDWyrry/wkTFmhoi0IHPyX0NjTEbxl2cTu/1uOGPLoTD3cpDfvRzFxJb6EJGOwHigmzHmcgnVBgXX5w80BL4Vkd/JvC6NKqFOSVt/tv8zxqQaYw4B+8kMi5JgS32DgSUAxpgtgA+Z9zQ4C5v+fdqkJDtTbOxw8QIOAnX5p1PoNotjniFnh+QSJ6uvMZkdW/Wc8ftncfy3lFyHpC3fu1Dg46zHlclsIldyovrWAIOyHt+a9YsnJfwzrkPeHZJh5OyQ3FbkzynJ/6lC/M93AX7J+gUbn/XaJDL/CkNmWn8BHAC2ATc6WX3rgBPAD1n/RTlTfRbHllg42Pi9E2AmsA/YC/Rxpu8dmSMUMVnB8QNwXwnXtxA4BqSS2UoYDDwJPJnt+zcnq/691/Kz1RmSSimrnLHPQSnlBDQclFJWaTgopazScFBKWaXhoJSySsNBKWWVhoNSyioNB6WUVf8PR0d1bfqE/5IAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 288x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Fit a linear regression, and the parameters of an adaptive loss.\n",
    "regression = RegressionModel()\n",
    "adaptive = robust_loss_pytorch.adaptive.AdaptiveLossFunction(\n",
    "    num_dims=1, float_dtype=np.float32, device='cpu')\n",
    "# A single optimizer updates the regression weights and the loss parameters.\n",
    "params = [*regression.parameters(), *adaptive.parameters()]\n",
    "optimizer = torch.optim.Adam(params, lr=0.01)\n",
    "\n",
    "for epoch in range(2000):\n",
    "    residual = regression(x) - y\n",
    "\n",
    "    # The residual is unsqueezed to an (n,1) matrix before computing the loss.\n",
    "    # A matrix with this shape corresponds to a loss with one shape+scale\n",
    "    # parameter per dimension, and this data has only one dimension.\n",
    "    loss = adaptive.lossfun(residual.unsqueeze(1)).mean()\n",
    "\n",
    "    optimizer.zero_grad()\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "\n",
    "    if epoch % 100 == 0:\n",
    "        # You can see the alpha+scale parameters moving around.\n",
    "        alpha_now = adaptive.alpha()[0, 0].item()\n",
    "        scale_now = adaptive.scale()[0, 0].item()\n",
    "        print('{:<4}: loss={:03f}  alpha={:03f}  scale={:03f}'.format(\n",
    "            epoch, loss.item(), alpha_now, scale_now))\n",
    "\n",
    "# It fits!\n",
    "plot_regression(regression)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
